import tensorflow as tf
from tensorflow import keras

# Convenience alias for the Keras layers module (not used by the code below).
layers = keras.layers

# Stand-in "predictions": a batch of 2 rows with 10 random scores each.
x = tf.random.normal([2, 10])
print(x)

# Integer class labels, one per row of `x`.
y = tf.constant([1, 3])
print(y)
# One-hot encode the labels so they match the [2, 10] shape of `x`.
y_onehot = tf.one_hot(y, depth=10)
print(y_onehot)
# Per-row mean squared error, averaged over the last axis (one value per row).
# Use the canonical `mean_squared_error` name: the short `MSE` alias is a
# legacy name that Keras 3 (the default `tf.keras` from TF 2.16 on) no longer
# exports from `keras.losses`.
loss = keras.losses.mean_squared_error(y_onehot, x)
print(loss)
# Collapse the per-row losses into a single scalar for the whole batch.
loss = tf.reduce_mean(loss)
print(loss)
